import sys
import random
from collections import defaultdict

import numpy as np
from allennlp.predictors import Predictor

from config import Config
from tools.device_manager import DeviceManager
from tools.saver import Saver
from tools.utils import pos_tags, write_json, read_json
from tools.time_counter import TimeCounter
from victim_model.text_classifier import TextClassifier
from victim_model.text_predictor import TextPredictor, TextPredictorEnsembler
from .attacker import BlackBoxAttacker
from .candidates import WordNetCandidate
from dataset import *


class PreUniversalAttack(BlackBoxAttacker):
    """Learn a per-label word-substitution table and pre-compute adversarial chains.

    Workflow (as driven by ``main`` in this module):

    1. ``set_texts``  -- tokenize and POS-tag the victim texts.
    2. ``get_pairs``  -- score every (word -> synonym) substitution against the
       victim model, keep the best synonym per (label, word), and write the
       table to ``cf.p_adv['pre_universal']``.
    3. ``generate_adv_examples`` -- for each text, apply the selected
       substitutions one after another and write the resulting chains to
       ``cf.p_adv['universal']``.
    """

    def __init__(self, cf: Config, predictor: Predictor):
        super().__init__(cf, predictor)
        # Synonym source restricted to the POS tags the base attacker supports.
        self.synonym_candidate = WordNetCandidate(self.supported_postag)
        self.unk_candidate = '@@UNKNOWN@@'
        # Tokenized, POS-tagged victim texts (filled by ``set_texts``).
        self.texts = []
        # A list of (label, source, target, score) tuples while scoring is in
        # progress; replaced by a nested dict
        # ``{label: {source: (target, mean_score)}}`` once
        # ``init_universal_candidates`` has run.
        self.pairs = []

    def set_texts(self, texts):
        """Tokenize and POS-tag ``texts`` and append them to ``self.texts``.

        Each item is a mapping whose ``'sentence'`` entry is replaced *in
        place* by its list of token strings, and which gains a ``'tag'``
        entry holding the output of ``pos_tags`` for those tokens. Items are
        appended, so repeated calls accumulate.
        """
        for text in texts:
            text['sentence'] = [t.text for t in self.tokenizer.tokenize(text['sentence'])]
            text['tag'] = pos_tags(text['sentence'])
            self.texts.append(text)

    def get_pairs(self):
        """Score all synonym substitutions, reduce them to the best per word, and save.

        After this call ``self.pairs`` is the nested best-pair dict, which is
        also what gets written to ``cf.p_adv['pre_universal']``.
        """
        self.pairs = []
        for i, text in enumerate(self.texts):
            if not self.cf.quiet:
                sys.stdout.write(f'\r{i}')
                # '\r' emits no newline, so flush explicitly or the progress
                # counter may sit in the buffer and never be displayed.
                sys.stdout.flush()
            self.pairs.extend(self.get_synonym_saliency(text))

        self.init_universal_candidates()
        write_json(self.cf.p_adv['pre_universal'], self.pairs)
        # A previously saved table could be reloaded instead of recomputed:
        # self.pairs = read_json(self.cf.p_adv['pre_universal'])

    def get_synonym_saliency(self, text):
        """Score every single-word synonym substitution of ``text``.

        The original text and all substituted variants are sent to the victim
        model in one batch. For each variant the per-class logit change is
        oriented so that "gold logit went down" and "other logit went up" are
        both positive, and the mean over classes is the variant's score.

        Returns:
            A list of ``(label, source_word, synonym, score)`` tuples, one per
            candidate substitution (empty when no word has a synonym).
        """
        positions = []
        synonyms = []
        batch = [text]  # index 0 is the unmodified text (the reference)

        for i, (word, tag) in enumerate(zip(text['sentence'], text['tag'])):
            for synonym in self.synonym_candidate.candidate_set(word, tag):
                # FIXME (left by the original author; the intended fix is not stated)
                batch.append(self.subsitude(text, i, synonym))
                positions.append(i)
                synonyms.append(synonym)

        outputs = self.predict_batch_data(batch)
        all_logits = [o['logits'] for o in outputs]
        gold_idx = outputs[0]['gold']
        logit_x, logit_candidates = all_logits[0], all_logits[1:]

        pairs = []
        for i, logit_candidate, synonym in zip(positions, logit_candidates, synonyms):
            # ``gold_idx`` is compared with a tolerance rather than ``==``
            # (presumably because it may arrive as a float -- confirm).
            scores = [
                (original - candidate) if abs(j - gold_idx) < 1e-3 else (candidate - original)
                for j, (candidate, original) in enumerate(zip(logit_candidate, logit_x))
            ]
            pairs.append((text['label'], text['sentence'][i], synonym, np.average(scores)))

        return pairs

    def init_universal_candidates(self):
        """Collapse the raw pair list into ``{label: {source: (target, mean_score)}}``.

        Scores of identical (label, source, target) triples are averaged, and
        for every (label, source) only the target with the highest mean is
        kept (first one wins on ties). ``self.pairs`` is replaced by the
        resulting plain nested dict.
        """
        grouped = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for gold, src, tgt, score in self.pairs:
            grouped[gold][src][tgt].append(score)

        best_pair = dict()
        for gold, by_source in grouped.items():
            best_pair[gold] = dict()
            for src, by_target in by_source.items():
                averaged = [(tgt, np.average(scores)) for tgt, scores in by_target.items()]
                best_pair[gold][src] = max(averaged, key=lambda item: item[1])

        self.pairs = best_pair

    def get_victim_substitute_pair(self, text):
        """Pick the substitutions to apply to ``text``, strongest first.

        Returns:
            A list of ``(position, substitute_word)`` tuples, sorted by
            descending score and truncated to ``self.attack_num(len(sentence))``
            entries. Empty when no word of the text has a learned substitute.
        """
        # A label for which no pair was ever learned simply has no candidates;
        # indexing ``self.pairs`` directly would raise KeyError here.
        label_pairs = self.pairs.get(text['label'], {})

        H = [
            (i, *label_pairs[w])  # (position, substitute, score)
            for i, w in enumerate(text['sentence'])
            if w in label_pairs
        ]

        attack_num = self.attack_num(len(text['sentence']))

        H.sort(key=lambda item: item[2], reverse=True)
        return [(position, substitute) for position, substitute, _ in H[:attack_num]]

    def generate_adv_examples(self):
        """Build a chain of progressively perturbed copies for every text and save them.

        Chain element 0 is a copy of the text; element k applies the k-th
        selected substitution on top of element k-1. All chains are written
        to ``cf.p_adv['universal']``.

        Note: the stored text's ``'tag'`` entry is reset to ``[]`` before
        copying, so POS tags are no longer available on ``self.texts``
        afterwards.
        """
        all_adv_examples = []

        for i, text in enumerate(self.texts):
            if not self.cf.quiet:
                sys.stdout.write(f'\r{i}')
                sys.stdout.flush()

            H = self.get_victim_substitute_pair(text)
            text['tag'] = []

            adv_examples = [self.copy_text(text)]
            for position, substitute in H:
                adv_examples.append(self.subsitude(adv_examples[-1], position, substitute))

            all_adv_examples.append(adv_examples)

        write_json(self.cf.p_adv['universal'], all_adv_examples)


class UniversalAttack(BlackBoxAttacker):
    """Replay the pre-computed substitution chains against a victim model.

    All chains stored at ``cf.p_adv['universal']`` are evaluated eagerly in
    the constructor; ``forward`` then just hands out the cached results one
    by one, in file order.
    """

    def __init__(self, cf: Config, predictor: Predictor):
        super().__init__(cf, predictor)
        self.all_adv_examples = read_json(self.cf.p_adv['universal'])
        # One attack result per chain, in the same order as the file.
        self.result = self.attack()
        # Index of the next result ``forward`` will return.
        self.count = 0

    def forward(self, text):
        """Return the next cached attack result.

        ``text`` is ignored: results are served strictly in the order of the
        pre-computed chains, so callers must iterate texts in that same order.
        Raises ``IndexError`` once every cached result has been consumed.
        """
        ret = self.result[self.count]
        self.count += 1
        return ret

    def attack(self):
        """Evaluate every chain and return the list of attack results.

        ``cf.batch_size`` chains are flattened into a single
        ``predict_batch_data`` call; the outputs are then regrouped per chain
        and passed to ``get_output``.
        """
        results = []
        batch_size = self.cf.batch_size
        total = len(self.all_adv_examples)

        for start in range(0, total, batch_size):
            if not self.cf.quiet:
                sys.stdout.write(f'\r{start}')
                # '\r' emits no newline, so flush explicitly or the progress
                # counter may sit in the buffer and never be displayed.
                sys.stdout.flush()

            # Flatten the chains of this batch, remembering for every text
            # which chain (offset within the batch) it belongs to.
            owners = []
            flat_texts = []
            for offset in range(min(batch_size, total - start)):
                for text in self.all_adv_examples[start + offset]:
                    owners.append(offset)
                    flat_texts.append(text)

            outputs = self.predict_batch_data(flat_texts)

            per_chain = defaultdict(list)
            for offset, output in zip(owners, outputs):
                per_chain[offset].append(output)

            # Insertion order equals ascending offset, so results stay aligned
            # with ``self.all_adv_examples``.
            for offset, chain_outputs in per_chain.items():
                results.append(self.get_output(chain_outputs, self.all_adv_examples[start + offset]))
        return results

    def get_output(self, outputs, texts):
        """Select the adversarial example of one chain and build its attack result.

        Args:
            outputs: model outputs for the chain, aligned with ``texts``;
                element 0 belongs to the unmodified text.
            texts: the chain of progressively perturbed examples.

        The first index flagged by ``self.stop(outputs)`` is used; when
        nothing is flagged, the last element of the chain is used. The chosen
        example is modified in place: its token list is detokenized and its
        ``'tag'`` entry is removed.
        """
        stop = np.nonzero(self.stop(outputs))[0]
        length = stop[0] if len(stop) > 0 else len(outputs) - 1
        adv_text = texts[length]

        # ``length`` is both the chain index and the number of substitutions
        # applied so far, so this is the fraction of tokens that were changed.
        # An empty sentence cannot have been perturbed; report 0.0 instead of
        # raising ZeroDivisionError.
        token_count = len(adv_text['sentence'])
        ratio = length / token_count if token_count else 0.0

        adv_text['sentence'] = self.tokenizer.detokenize(adv_text['sentence'])
        # Tolerate examples that carry no 'tag' entry.
        adv_text.pop('tag', None)

        gold = outputs[0]['gold']
        pred = outputs[length]['pred']
        # Compared with a tolerance rather than ``!=`` (presumably because the
        # values may arrive as floats -- confirm).
        success = bool(abs(gold - pred) > 1e-3)

        return self.attack_result(success=success,
                                  length=ratio,
                                  adv_example=adv_text)


# Module-level configuration instance, created as a side effect of importing
# this module.
# NOTE(review): nothing in the visible code reads this global -- main() builds
# its own Config(). Confirm that no other module imports `cf` from here before
# removing it.
cf = Config()


def build_data_reader(cf):
    """Return the dataset reader class selected by ``cf.dataset``.

    Supported values are ``'imdb'``, ``'agnews'`` and ``'mr'``. The class
    itself is returned, not an instance.

    Raises:
        ValueError: if ``cf.dataset`` is none of the supported names.
    """
    dataset = cf.dataset
    # Reader names are resolved only on the branch that is taken, so a reader
    # that is not exported by ``dataset`` affects only its own branch.
    if dataset == 'imdb':
        return IMDBDatasetReader
    if dataset == 'agnews':
        return AGNewsDatasetReader
    if dataset == 'mr':
        return MRDatasetReader
    # The message lists every supported dataset (it previously omitted 'mr').
    raise ValueError(f'{dataset} not implemented. Only support: imdb, agnews and mr.')


def build_predictor(Reader, cf):
    """Load the trained victim model(s) and wrap them in a predictor.

    ``cf.encoder`` and ``cf.token`` are comma-separated lists that are paired
    up position by position (extra entries in the longer list are ignored by
    ``zip``). For each pair the last saved epoch of the matching
    ``TextClassifier`` is loaded and moved to the GPU.

    Returns:
        The single ``TextPredictor`` when exactly one pair is configured,
        otherwise a ``TextPredictorEnsembler`` over all of them.
    """
    encoder_types = cf.encoder.split(',')
    token_types = cf.token.split(',')

    def load_member(encoder_type, token_type):
        # Build the reader first, then restore the checkpoint saved under
        # "<dataset>_<encoder>_<token>".
        member_reader = Reader(cf, token_type=token_type)
        checkpoint = Saver(f'{cf.dataset}_{encoder_type}_{token_type}')
        classifier = checkpoint.load_last_epoch(
            TextClassifier,
            {'cf': cf, 'encoder_type': encoder_type, 'token_type': token_type},
        )
        classifier = classifier.cuda()
        return TextPredictor(classifier, member_reader)

    members = [
        load_member(encoder_type, token_type)
        for encoder_type, token_type in zip(encoder_types, token_types)
    ]

    if len(members) != 1:
        return TextPredictorEnsembler(members)
    return members[0]


def main():
    """Run the pre-universal attack pipeline on the test split.

    Loads the victim predictor, reads the test split with a word-level
    reader, learns the substitution table and writes the adversarial chains.
    Everything runs inside the ``DeviceManager`` context for ``cf.device``.
    """
    config = Config()
    with DeviceManager(config.device):
        reader_cls = build_data_reader(config)
        victim_predictor = build_predictor(reader_cls, config)
        attacker = PreUniversalAttack(config, victim_predictor)

        # The victim texts are always read with word-level tokens,
        # independently of the token types used by the predictor(s).
        word_reader = reader_cls(config, token_type='word')
        test_texts = word_reader.read_json(config.p_split['test'])

        attacker.set_texts(test_texts)
        attacker.get_pairs()
        attacker.generate_adv_examples()
